from typing import Tuple

import torch
import numpy as np
from torch.distributions import Normal, VonMises

import matplotlib.pyplot as plt

class MetadynamicsSampler():
    """
    A class to facilitate metadynamics exploration in a given continuous environment.

    A batch of positions ``z`` and momenta ``p`` is evolved with Langevin
    dynamics (Euler-Maruyama scheme). The force on each sample is read from
    gridded gradients of two potentials:

    * a bias potential that accumulates kernel density around visited positions;
    * a confining potential computed from the ratio of a reward-weighted kernel
      density estimate to the position kernel density estimate.

    The environment ``env`` must provide ``dim``, ``angle_dim``, ``lower_bound``,
    ``upper_bound``, ``grid_spacing``, ``num_grid_points``, ``marginal_grid``
    and ``log_reward``.
    """
    def __init__(self,
                 config,
                 env,
                 ):
        """
        Reads the hyperparameters from ``config``, draws the initial positions
        and momenta, and allocates the gridded potentials and density estimates.

        Args:
            config: Mapping with the keys ``"device"`` and ``"batch_size"`` and
                a ``"metad"`` sub-mapping holding ``delta_t``, ``n``, ``beta``,
                ``epsilon``, ``gamma``, ``w``, ``kde_widths``, ``z_mu0s``,
                ``z_var0s``, ``p_mu0s`` and ``p_var0s``.
            env: The continuous environment to explore.
        """
        self.config = config
        self.device = torch.device(config["device"])
        self.batch_size = torch.tensor(self.config["batch_size"], device=self.device)
        self.env = env

        metad_config = self.config["metad"]
        self.delta_t = torch.tensor(metad_config["delta_t"], device=self.device)  # integration time step
        self.n = torch.tensor(metad_config["n"], device=self.device)              # integration steps per sample() call
        self.beta = torch.tensor(metad_config["beta"], device=self.device)        # scales the noise and the confining potential
        self.epsilon = torch.tensor(metad_config["epsilon"], device=self.device)  # guards the division and the log
        self.gamma = torch.tensor(metad_config["gamma"], device=self.device)      # friction coefficient
        self.w = torch.tensor(metad_config["w"], device=self.device)              # scale of each bias deposit

        # Positions and momenta, each of shape (batch_size, env.dim)
        self.z, self.p = self._init_positions_momenta()

        # Scalar fields live on the state space grid; gradient fields carry an
        # extra leading axis with one entry per state space dimension.
        grid_shape = self.env.num_grid_points.tolist()
        grad_shape = [self.env.dim] + grid_shape
        self.bias_potential = torch.zeros(grid_shape, device=self.device)
        self.grad_bias_potential = torch.zeros(grad_shape, device=self.device)
        self.position_kde = torch.zeros(grid_shape, device=self.device)
        self.reward_kde = torch.zeros(grid_shape, device=self.device)
        self.confining_potential = torch.zeros(grid_shape, device=self.device)
        self.grad_confining_potential = torch.zeros(grad_shape, device=self.device)
        self.iteration_number = 0

        # Per-dimension kernel widths used for the density estimates
        self.kde_widths = torch.tensor(metad_config["kde_widths"], device=self.device)

        # assert self.env.dim <= 2, "Metadynamics is currently only supported for 1D and 2D environments."

    @staticmethod
    def _wrap_to_pi(angles):
        """
        Wraps angles of any magnitude back into the interval [-pi, pi].
        """
        # fmod keeps the sign of its input, so one pass only folds values that
        # overshoot on one side; the second pass folds the other side.
        angles = torch.fmod(angles + np.pi, 2 * np.pi) - np.pi
        return torch.fmod(angles - np.pi, 2 * np.pi) + np.pi

    def _apply_boundary_conditions(self):
        """
        Brings ``self.z`` back inside the grid: angular dimensions are wrapped,
        all other dimensions are reflected at the lower and upper bounds.
        """
        for dim in range(self.env.dim):
            if self.env.angle_dim[dim]:
                # If dimension is angular, wrap coordinates that stray outside of the grid boundaries
                self.z[:, dim] = self._wrap_to_pi(self.z[:, dim])
            else:
                # If dimension is not angular, mirror samples that stray outside of the grid boundaries.
                # NOTE(review): only the position is mirrored, the momentum keeps its sign - confirm this is intended.
                z_clamped = torch.clamp(self.z[:, dim], self.env.lower_bound[dim], self.env.upper_bound[dim])
                self.z[:, dim] = self.z[:, dim] + 2 * (z_clamped - self.z[:, dim])

    def _init_positions_momenta(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Initializes the positions and momenta of the samples.

        For each dimension the position is drawn uniformly between the
        environment bounds when both ``z_mu0s[dim]`` and ``z_var0s[dim]`` are
        ``None`` or ``-1``, and from a Gaussian otherwise. Momenta are always
        drawn from a Gaussian.

        NOTE: despite their names, ``z_var0s`` and ``p_var0s`` are handed to
        ``torch.normal`` as standard deviations.

        Returns:
            The tuple ``(z, p)``, each of shape (batch_size, env.dim).

        Raises:
            AssertionError: If only one of the Gaussian parameters is provided,
                or if a Gaussian draw for a non-angle dimension falls outside
                the environment bounds.
        """
        # Use a plain int for sizes instead of the 0-d tensor stored on the instance
        batch = int(self.batch_size)
        z = torch.zeros((batch, self.env.dim), device=self.device)
        p = torch.zeros((batch, self.env.dim), device=self.device)

        # The parameters for sampling initial position and momentum are provided in the config file
        z_mu0s = self.config["metad"]["z_mu0s"]
        z_var0s = self.config["metad"]["z_var0s"]
        p_mu0s = self.config["metad"]["p_mu0s"]
        p_var0s = self.config["metad"]["p_var0s"]

        # Initialize the positions and momenta for each dimension of the state space
        for dim in range(self.env.dim):
            mu_unset = (z_mu0s[dim] is None) or (z_mu0s[dim] == -1)
            var_unset = (z_var0s[dim] is None) or (z_var0s[dim] == -1)

            if mu_unset and var_unset:
                # If the initial position is not provided, sample uniformly from the bounds
                extent = self.env.upper_bound[dim] - self.env.lower_bound[dim]
                z[:, dim] = torch.rand(batch, device=self.device) * extent + self.env.lower_bound[dim]
            else:
                # Otherwise sample with a Gaussian prior
                assert z_mu0s[dim] is not None and z_var0s[dim] is not None, "Both mu and var for z should be provided for the initial position (Gaussian prior) or both should be None (uniform prior)"
                ones = torch.ones((batch,), device=self.device)
                z[:, dim] = torch.normal(z_mu0s[dim] * ones, z_var0s[dim] * ones)

                if self.env.angle_dim[dim]:
                    # Wrap the z values if they fall outside the bounds for angle dimensions
                    z[:, dim] = self._wrap_to_pi(z[:, dim])
                else:
                    # If any z values fall outside the bounds, raise an error for non-angle dimensions
                    assert torch.all((z[:, dim] >= self.env.lower_bound[dim]) & (z[:, dim] <= self.env.upper_bound[dim])), "Initial z values should be within the bounds for metadynamics. Try reducing the variance of the Gaussian prior."

            assert p_mu0s[dim] is not None and p_var0s[dim] is not None, "Both mu and var for p should be provided for the initial position (Gaussian prior)"
            ones = torch.ones((batch,), device=self.device)
            p[:, dim] = torch.normal(p_mu0s[dim] * ones, p_var0s[dim] * ones)

        return z, p

    def _compute_log_probabilities(self):
        """
        Computes the kde log probabilities of the current samples over the state space grid.

        Each sample contributes a product kernel centred on its position: a
        von Mises kernel along angular dimensions and a Gaussian kernel along
        the others.

        Returns:
            Tensor of shape (batch_size, *grid sizes) with the log density of
            every sample's kernel at every grid point.
        """
        num_samples = self.z.size(0)
        grid_sizes = tuple(len(self.env.marginal_grid[dim]) for dim in range(self.env.dim))
        log_probs = torch.zeros((num_samples,) + grid_sizes, device=self.device)

        for dim in range(self.env.dim):
            centres = self.z[:, dim].reshape(-1, 1)
            if self.env.angle_dim[dim]:
                # The kernel width is converted to a von Mises concentration
                concentration = 1 / self.kde_widths[dim]**2
                kernel = VonMises(centres, concentration)
            else:
                kernel = Normal(centres, self.kde_widths[dim])
            marginal_log_probs = kernel.log_prob(self.env.marginal_grid[dim])     # batchsize x num_grid_points[dim]

            # Reshape to [batch_size, 1, ..., num_grid_points[dim], ..., 1] so that
            # the marginal broadcasts along its own grid axis only
            shape = [num_samples] + [1] * dim + [marginal_log_probs.size(1)] + [1] * (self.env.dim - dim - 1)

            # The log of a product kernel is the sum of the marginal logs
            log_probs += marginal_log_probs.reshape(*shape)

        return log_probs

    def _update_potentials(self, batch_reward):
        """
        Updates the bias and confining potentials and their derivatives.

        Args:
            batch_reward: Rewards of the current samples, one per sample.
        """
        # Kernel densities of every sample on the grid, computed once and reused
        kernel_density = torch.exp(self._compute_log_probabilities())
        boost = kernel_density.sum(dim=0)

        self.bias_potential += self.n * self.delta_t * self.w * boost
        self.grad_bias_potential = torch.stack(torch.gradient(self.bias_potential, spacing=self.env.grid_spacing.tolist()))

        self.position_kde += boost
        # Trailing singleton axes let the per-sample rewards broadcast over the grid
        reward_weights = batch_reward[(...,) + (None,) * self.env.dim]
        self.reward_kde += (reward_weights * kernel_density).sum(dim=0)

        # epsilon protects both the division and the logarithm against zeros
        reward_ratio = self.reward_kde / (self.position_kde + self.epsilon)
        self.confining_potential = - (1/self.beta) * torch.log(reward_ratio + self.epsilon)
        self.grad_confining_potential = torch.stack(torch.gradient(self.confining_potential, spacing=self.env.grid_spacing.tolist()))

    def sample(self):
        """
        Applies a step of metadynamics and returns the new sample.

        Runs ``n`` Langevin integration steps, evaluates the reward of the
        resulting positions and updates the potentials with them.

        Returns:
            The tuple ``(z, log_rewards)``: a copy of the new positions with
            shape (batch_size, env.dim), inside the environment bounds, and the
            1-D tensor of their log rewards.
        """
        num_steps = int(self.n)

        # Leading-axis index selecting the gradient component of every dimension,
        # created on the same device as the other index tensors
        component_idx = torch.arange(self.grad_bias_potential.size(0), device=self.device)[:, None]
        # Largest valid grid index along each dimension
        max_idx = torch.tensor(list(self.grad_bias_potential.shape[1:]), device=self.device) - 1

        # Make sure the starting positions lie on the grid
        self._apply_boundary_conditions()

        # Apply n steps of metadynamics simulation to the current sample
        for _ in range(num_steps):
            # Compute the grid indices of the current samples. Clamping keeps
            # samples sitting exactly on the upper bound inside the grid and
            # prevents negative indices from wrapping around to the far side.
            idx = torch.div((self.z - self.env.lower_bound), self.env.grid_spacing).floor().long()
            idx = torch.minimum(idx.clamp(min=0), max_idx)

            # Compute the force on the samples
            indices = tuple([component_idx] + [idx[:, d] for d in range(self.env.dim)])
            grad_total = self.grad_bias_potential[indices] + self.grad_confining_potential[indices]
            # Gradients are clipped to keep the integration stable; transpose to (batch_size, dim)
            F = -grad_total.clamp(min=-10, max=10).T

            # Update the positions and momenta using Langevin dynamics (Euler-Maruyama scheme)
            self.z += self.delta_t * self.p
            noise = torch.sqrt(2 * self.gamma / self.beta) * torch.randn_like(self.p, device=self.device) * torch.sqrt(self.delta_t)
            self.p += self.delta_t * F - self.gamma * self.p * self.delta_t + noise

            # Bring the updated positions back onto the grid so that the next
            # lookup, the reward evaluation and the returned samples are in bounds
            self._apply_boundary_conditions()

        # Compute the rewards of the new samples (kept 1-D even for a batch of one)
        log_rewards = self.env.log_reward(self.z).reshape(-1)
        batch_reward = torch.exp(log_rewards)
        self._update_potentials(batch_reward)
        self.iteration_number += 1

        # Return a copy: self.z is modified in place by the next call
        return self.z.clone(), log_rewards
